{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example walks you through the basic usage of PromptBench. We hope it helps you get familiar with the APIs so that you can use them in your own projects later."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, import the package. PromptBench is used through a single unified import, `import promptbench as pb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/v-zhukaijie/anaconda3/envs/promptbench/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import promptbench as pb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset\n",
    "\n",
    "PromptBench provides a simple interface for loading datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All supported datasets: \n",
      "['sst2', 'cola', 'qqp', 'mnli', 'mnli_matched', 'mnli_mismatched', 'qnli', 'wnli', 'rte', 'mrpc', 'mmlu', 'squad_v2', 'un_multi', 'iwslt2017', 'math', 'bool_logic', 'valid_parentheses', 'gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking', 'last_letter_concat', 'numersense', 'qasc']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'content': \"it 's a charming and often affecting journey . \", 'label': 1},\n",
       " {'content': 'unflinchingly bleak and desperate ', 'label': 0},\n",
       " {'content': 'allows us to hope that nolan is poised to embark a major career as a commercial yet inventive filmmaker . ',\n",
       "  'label': 1},\n",
       " {'content': \"the acting , costumes , music , cinematography and sound are all astounding given the production 's austere locales . \",\n",
       "  'label': 1},\n",
       " {'content': \"it 's slow -- very , very slow . \", 'label': 0}]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# print all supported datasets in promptbench\n",
    "print('All supported datasets: ')\n",
    "print(pb.SUPPORTED_DATASETS)\n",
    "\n",
    "# load a dataset, sst2, for instance.\n",
    "# if the dataset is not available locally, it will be downloaded automatically.\n",
    "dataset = pb.DatasetLoader.load_dataset(\"sst2\")\n",
    "# dataset = pb.DatasetLoader.load_dataset(\"mmlu\")\n",
    "# dataset = pb.DatasetLoader.load_dataset(\"un_multi\")\n",
    "# dataset = pb.DatasetLoader.load_dataset(\"iwslt2017\", [\"ar-en\", \"de-en\", \"en-ar\"])\n",
    "# dataset = pb.DatasetLoader.load_dataset(\"math\", \"algebra__linear_1d\")\n",
    "# dataset = pb.DatasetLoader.load_dataset(\"bool_logic\")\n",
    "# dataset = pb.DatasetLoader.load_dataset(\"valid_parentheses\")\n",
    "\n",
    "# print the first 5 examples\n",
    "dataset[:5]"
   ]
  },
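  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each record shown above is a dictionary with a `content` string and an integer `label`. Assuming that layout, a quick sanity check such as counting labels needs only standard Python. The sketch below uses a hand-copied `sample` list (a hypothetical stand-in for a slice of the dataset) so that it runs without downloading anything."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Hypothetical stand-in: three records copied from the sst2 output above.\n",
    "sample = [\n",
    "    {'content': \"it 's a charming and often affecting journey . \", 'label': 1},\n",
    "    {'content': 'unflinchingly bleak and desperate ', 'label': 0},\n",
    "    {'content': \"it 's slow -- very , very slow . \", 'label': 0},\n",
    "]\n",
    "\n",
    "# Count how many records carry each label.\n",
    "label_counts = Counter(record['label'] for record in sample)\n",
    "print(label_counts)"
   ]
  },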
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models\n",
    "\n",
    "Next, load an LLM through PromptBench."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All supported models: \n",
      "['google/flan-t5-large', 'llama2-7b', 'llama2-7b-chat', 'llama2-13b', 'llama2-13b-chat', 'llama2-70b', 'llama2-70b-chat', 'phi-1.5', 'phi-2', 'palm', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-1106-preview', 'gpt-3.5-turbo-1106', 'vicuna-7b', 'vicuna-13b', 'vicuna-13b-v1.3', 'google/flan-ul2', 'gemini-pro']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "# print all supported models in promptbench\n",
    "print('All supported models: ')\n",
    "print(pb.SUPPORTED_MODELS)\n",
    "\n",
    "# load a model, flan-t5-large, for instance.\n",
    "model = pb.LLMModel(model='google/flan-t5-large', max_new_tokens=10, temperature=0.0001, device='cuda')\n",
    "# model = pb.LLMModel(model='llama2-13b-chat', max_new_tokens=10, temperature=0.0001)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct prompts\n",
    "\n",
    "Prompts are the key interface for interacting with LLMs. You can construct a prompt by calling the Prompt API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt API supports a list, so you can pass multiple prompts at once.\n",
    "prompts = pb.Prompt([\"Classify the sentence as positive or negative: {content}\",\n",
    "                     \"Determine the emotion of the following sentence as positive or negative: {content}\"\n",
    "                     ])"
   ]
  },
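  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what the `{content}` placeholder is for, the cell below fills a template with one record using plain `str.format`. This is a minimal sketch only: the names `template` and `record` are hypothetical, and it is not a description of how PromptBench fills prompts internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical stand-alone illustration; not PromptBench's implementation.\n",
    "template = \"Classify the sentence as positive or negative: {content}\"\n",
    "record = {\"content\": \"unflinchingly bleak and desperate \", \"label\": 0}\n",
    "\n",
    "# Substitute the record's text for the {content} placeholder.\n",
    "input_text = template.format(content=record[\"content\"])\n",
    "print(input_text)"
   ]
  },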
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You may need to define a projection function for the model output,\n",
    "because the output format requested in your prompts may differ from the label format of the dataset.\n",
    "For example, in the sst2 dataset the labels are 0 and 1, representing 'negative' and 'positive',\n",
    "but the model outputs the words 'negative' and 'positive'.\n",
    "The projection function maps the model output to the dataset label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def proj_func(pred):\n",
    "    mapping = {\n",
    "        \"positive\": 1,\n",
    "        \"negative\": 0\n",
    "    }\n",
    "    return mapping.get(pred, -1)"
   ]
  },
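  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mapping above matches only the exact strings 'positive' and 'negative' and returns -1 for anything else. If you ever apply a projection to raw model text yourself, a more tolerant variant may be useful. The sketch below is hypothetical: `normalize_label` and `robust_proj_func` are illustrative names, not part of PromptBench, and the sketch makes no assumption about any cleanup that `pb.OutputProcess.cls` performs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "\n",
    "def normalize_label(text):\n",
    "    # Trim whitespace and surrounding punctuation, then lowercase.\n",
    "    return text.strip().strip(string.punctuation).strip().lower()\n",
    "\n",
    "def robust_proj_func(pred):\n",
    "    # Hypothetical variant of proj_func; unknown outputs still map to -1.\n",
    "    mapping = {\"positive\": 1, \"negative\": 0}\n",
    "    return mapping.get(normalize_label(pred), -1)\n",
    "\n",
    "for raw in [\"positive\", \" Negative.\", \"neutral\"]:\n",
    "    print(repr(raw), \"->\", robust_proj_func(raw))"
   ]
  },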
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform evaluation using prompts, datasets, and models\n",
    "\n",
    "Finally, you can perform a standard evaluation using the prompts, dataset, and model loaded above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/872 [00:00<?, ?it/s]/home/v-zhukaijie/anaconda3/envs/promptbench/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.0001` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "2023-12-24 05:02:56.792134: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2023-12-24 05:02:56.792188: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2023-12-24 05:02:56.793225: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2023-12-24 05:02:56.799581: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-12-24 05:02:57.481660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
      "  0%|          | 1/872 [00:02<29:19,  2.02s/it]/home/v-zhukaijie/anaconda3/envs/promptbench/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.0001` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "100%|██████████| 872/872 [00:58<00:00, 14.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.947, Classify the sentence as positive or negative: {content}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 872/872 [00:56<00:00, 15.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.947, Determine the emotion of the following sentence as positive or negative: {content}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "for prompt in prompts:\n",
    "    preds = []\n",
    "    labels = []\n",
    "    for data in tqdm(dataset):\n",
    "        # process input\n",
    "        input_text = pb.InputProcess.basic_format(prompt, data)\n",
    "        label = data['label']\n",
    "        raw_pred = model(input_text)\n",
    "        # process output\n",
    "        pred = pb.OutputProcess.cls(raw_pred, proj_func)\n",
    "        preds.append(pred)\n",
    "        labels.append(label)\n",
    "    \n",
    "    # evaluate\n",
    "    score = pb.Eval.compute_cls_accuracy(preds, labels)\n",
    "    print(f\"{score:.3f}, {prompt}\")"
   ]
  },
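  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the reported score easier to interpret, the cell below sketches classification accuracy in plain Python as the fraction of predictions that equal their labels. It assumes this is the quantity `pb.Eval.compute_cls_accuracy` reports; the helper `accuracy` and the toy `preds` and `labels` lists are hypothetical. Note that a prediction of -1 (an output the projection function could not map) simply counts as incorrect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(preds, labels):\n",
    "    # Hypothetical helper: fraction of predictions equal to their labels.\n",
    "    if len(preds) != len(labels):\n",
    "        raise ValueError(\"preds and labels must have the same length\")\n",
    "    if not labels:\n",
    "        return 0.0\n",
    "    correct = sum(p == l for p, l in zip(preds, labels))\n",
    "    return correct / len(labels)\n",
    "\n",
    "# Toy data: the third prediction is -1 (unmapped), so it counts as wrong.\n",
    "toy_preds = [1, 0, -1, 1]\n",
    "toy_labels = [1, 0, 0, 1]\n",
    "print(f\"{accuracy(toy_preds, toy_labels):.3f}\")"
   ]
  },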
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "promptbench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
